import numpy as np
import mujoco_py
import gym

# Camera parameters: tracked body id, distance, look-at point and elevation.
# NOTE(review): presumably meant for the MuJoCo viewer; nothing in the visible
# part of this file reads it -- confirm whether it is still needed.
DEFAULT_CAMERA_CONFIG = dict(
    trackbodyid=2,
    distance=4.0,
    lookat=np.array([0.0, 0.0, 1.15]),
    elevation=-20.0,
)


class Env(object):
    def __init__(self, seed=None, env_name='Walker2d-v3'):
        """Create and seed the Gym environment, then apply ``par_ini`` defaults.

        Args:
            seed: Value forwarded unchanged to ``env.seed``.
            env_name: Environment id handed to ``gym.make``.
        """
        self.env_name = env_name
        # Object returned by gym.make; the original note labels it the
        # time-limit wrapper, and the code reads ``_elapsed_steps`` from it.
        self.env = gym.make(env_name)
        # Environment one level below that wrapper. The helpers in this class
        # read simulator data (``sim.data``) and reward settings from it.
        self.env_ = self.env.env
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.env.seed(seed)
        self.par_ini()

    def par_ini(self, forward_reward_weight=1.0,
                ctrl_cost_weight=1e-3,  # 改 1e-3
                healthy_reward=0.0,  # 改 1.0
                terminate_when_unhealthy=True,
                healthy_z_range=(0.8, 2.0),
                healthy_angle_range=(-1.0, 1.0),
                reset_noise_scale=5e-2,  # 改 5e-3
                exclude_current_positions_from_observation=True):
        self.env_.forward_reward_weight = forward_reward_weight
        # self.env_._ctrl_cost_weight = ctrl_cost_weight
        # self.env_._healthy_reward = healthy_reward
        # self.env_._terminate_when_unhealthy = terminate_when_unhealthy
        # self.env_._healthy_z_range = healthy_z_range
        # self.env_._healthy_angle_range = healthy_angle_range
        # self.env_._reset_noise_scale = reset_noise_scale
        # self.env_._exclude_current_positions_from_observation = exclude_current_positions_from_observation

    def step(self, action):
        """Step the wrapped environment once with a flattened action.

        Args:
            action: Array whose elements form one action; it is flattened to
                1-D before being passed to ``env.step``.

        Returns:
            ``(observation, reward, done, info)`` where the first three are
            arrays reshaped to a single row and ``info`` is passed through.
        """
        flat_action = action.reshape(-1)
        obs, rew, finished, info = self.env.step(flat_action)
        obs_row = obs.reshape(1, -1)
        rew_row = np.array(rew).reshape(1, -1)
        done_row = np.array(finished).reshape(1, -1)
        return obs_row, rew_row, done_row, info

    def reset(self):
        """Reset the wrapped environment and return its observation as one row."""
        return self.env.reset().reshape(1, -1)

    def back_var(self):
        """Snapshot the stored (position, velocity) state and the step counter.

        The snapshot is written to ``state_back`` and ``_elapsed_steps_back``
        and can be reinstated with ``restore_var``.
        """
        # Store independent copies: reset_mul's off_line path writes into the
        # arrays held in self.state in place, which would otherwise silently
        # rewrite the backup as well.
        self.state_back = tuple(
            None if part is None else np.copy(part) for part in self.state
        )
        self._elapsed_steps_back = self.env._elapsed_steps

    def restore_var(self):
        """Reinstate the state and step counter saved by ``back_var``."""
        # Hand out copies so that later in-place edits of self.state (e.g. by
        # reset_mul's off_line path) leave the backup intact for another restore.
        self.state = tuple(
            None if part is None else np.copy(part) for part in self.state_back
        )
        self.env._elapsed_steps = self._elapsed_steps_back

    def penalty(self, action):
        """Return an L1 action penalty with one value per action row.

        Args:
            action: 2-D array of actions, one action per row.

        Returns:
            Column array (rows, 1) holding ``0.1 * sum(|action|)`` per row.
        """
        row_abs_sum = np.abs(action).sum(axis=1)
        return row_abs_sum.reshape(-1, 1) * 0.1

    def set_state_temp(self, qpos, qvel):
        """Replace the simulator's qpos/qvel, keeping time, act and udd_state.

        Args:
            qpos: Positions placed into the new simulator state.
            qvel: Velocities placed into the new simulator state.
        """
        # NOTE(review): no ``sim.forward()`` call follows the state change;
        # confirm derived quantities are refreshed elsewhere if they are needed.
        sim = self.env.sim
        current = sim.get_state()
        replacement = mujoco_py.MjSimState(
            current.time, qpos, qvel, current.act, current.udd_state
        )
        sim.set_state(replacement)

    def reset_temp(self):
        """Reset the simulator by hand and return the new observation as a row.

        Zeroes the step counter, resets the simulator, then sets a noisy start
        state: uniform noise in ``[-scale, scale]`` on the initial positions
        and ``scale`` times standard-normal noise on the initial velocities,
        where ``scale`` is the inner environment's ``_reset_noise_scale``.
        """
        wrapper = self.env
        wrapper._elapsed_steps = 0
        wrapper.sim.reset()

        scale = self.env_._reset_noise_scale
        rng = wrapper.np_random

        # Positions are drawn before velocities; the order matters for
        # reproducibility because both use the same random generator.
        position_noise = rng.uniform(low=-scale, high=scale, size=wrapper.model.nq)
        qpos = wrapper.init_qpos + position_noise
        velocity_noise = scale * rng.randn(wrapper.model.nv)
        qvel = wrapper.init_qvel + velocity_noise
        self.set_state_temp(qpos, qvel)

        return self.env_._get_obs().reshape(1, -1)

    def reset_mul(self, size=1, lim=None, off_line=None, state=None):
        """Reset a batch of environments that are simulated one after another.

        Two modes:

        * ``off_line is None``: perform ``size`` resets and store the resulting
          positions/velocities in ``self.state``.
        * otherwise: re-reset only the rows listed in ``off_line``, updating
          ``state`` and the arrays in ``self.state`` in place.

        Args:
            size: Number of resets in fresh mode.
            lim: If not None (fresh mode only), forwarded together with
                ``size`` to ``reset_done_counter``.
            off_line: Iterable of row indices to reset again, or None.
            state: Observation batch to update in ``off_line`` mode; ignored
                in fresh mode.

        Returns:
            The observation batch (one row per environment); None in fresh
            mode when ``size`` is 0.
        """
        if off_line is None:
            if lim is not None:
                self.reset_done_counter(size, lim)
            observations, positions, velocities = [], [], []
            for _ in range(size):
                observations.append(self.reset())
                data = self.env_.sim.data
                # Copy the snapshots: qpos/qvel are presumably views into the
                # simulator's memory (mujoco-py exposes them that way), so an
                # un-copied row would be overwritten by the next reset.
                positions.append(np.copy(data.qpos).reshape(1, -1))
                velocities.append(np.copy(data.qvel).reshape(1, -1))
            if observations:
                state = np.concatenate(observations, 0)
                self.state = np.concatenate(positions, 0), np.concatenate(velocities, 0)
            else:
                # Nothing was reset: keep the historical "all None" result.
                state = None
                self.state = None, None
        else:
            pos, vel = self.state
            for i in off_line:
                state[i, :] = self.reset()
                # Element-wise assignment copies the values into the batch rows.
                pos[i, :], vel[i, :] = self.env_.sim.data.qpos, self.env_.sim.data.qvel
            self.state = pos, vel
        return state

    def step_mul(self, action):
        """Step every environment of the batch once, reusing one simulator.

        For batches of more than one row the simulator is loaded with each
        row's stored position/velocity before stepping; the resulting
        positions/velocities replace ``self.state``.

        Args:
            action: 2-D array with one action per row; the row count is the
                batch size.

        Returns:
            ``(next_s, reward, penalty, done)``. ``next_s``, ``reward`` and
            ``done`` have one row per environment (all None for an empty
            batch), and ``penalty`` comes from ``self.penalty(action)``. For
            batches larger than one, ``done`` is post-processed by
            ``reshape_done``.
        """
        size = action.shape[0]
        pos, vel = self.state
        next_rows, reward_rows, done_rows = [], [], []
        position_rows, velocity_rows = [], []
        for i in range(size):
            if size > 1:
                self.env.set_state(pos[i], vel[i])  # set_state_temp
            if i > 0:
                # All rows share one ``_elapsed_steps`` counter, which each
                # step increments (assumed: Gym's time-limit wrapper). Undo
                # the increment of the previous row so a whole batch counts
                # as a single step; the former ``i > 1`` test let the counter
                # advance twice per batch.
                self.env._elapsed_steps -= 1
            s_, r, d, _ = self.step(action[i])
            data = self.env_.sim.data
            # Copy the snapshots: qpos/qvel are presumably views into the
            # simulator's memory (mujoco-py exposes them that way), so an
            # un-copied row would be overwritten when the next row is stepped.
            position_rows.append(np.copy(data.qpos).reshape(1, -1))
            velocity_rows.append(np.copy(data.qvel).reshape(1, -1))
            next_rows.append(s_)
            reward_rows.append(r)
            done_rows.append(d)
        if next_rows:
            next_s = np.concatenate(next_rows, 0)
            reward = np.concatenate(reward_rows, 0)
            done = np.concatenate(done_rows, 0)
            position = np.concatenate(position_rows, 0)
            velocity = np.concatenate(velocity_rows, 0)
        else:
            # Empty batch: keep the historical "all None" result.
            next_s, reward, done, position, velocity = None, None, None, None, None
        self.state = position, velocity
        penalty = self.penalty(action)
        if size > 1:
            done = self.reshape_done(reward, done)
        return next_s, reward, penalty, done

    def reset_done_counter(self, size, lim):
        """Start a zeroed per-environment counter and remember its limit.

        Args:
            size: Number of rows in the (size, 1) counter array.
            lim: Threshold stored in ``self.lim`` (compared against the
                counter in ``reshape_done``).
        """
        self.lim = lim
        self.counter = np.zeros(shape=(size, 1))

    def reshape_done(self, reward, done):
        """Force ``done`` for rows that collected too many negative rewards.

        Each row with a negative reward has its counter incremented. Rows
        whose counter has reached ``self.lim`` are marked done and their
        counter is cleared. Counts are cumulative: a non-negative reward does
        not reset them.

        Args:
            reward: Array with one reward per row.
            done: 2-D done array; modified in place.

        Returns:
            The same ``done`` array that was passed in.
        """
        negative_rows = np.flatnonzero(reward.reshape(-1) < 0)
        self.counter[negative_rows, :] += 1
        expired_rows = np.flatnonzero(self.counter.reshape(-1) >= self.lim)
        done[expired_rows, :] = True
        self.counter[expired_rows, :] = 0
        return done

    def test(self):
        """Run the environment forever with uniformly random actions.

        Each iteration samples one action between the action-space bounds,
        steps via ``step_mul``, prints the reward, renders, and resets via
        ``reset_mul`` when the step reports done. The loop never exits.
        """
        upper = self.env.action_space.high
        lower = self.env.action_space.low
        self.reset_mul()
        while True:
            random_action = self.env_.np_random.uniform(
                low=lower,
                high=upper,
                size=(1, self.env.action_space.shape[0]),
            )
            # step_mul returns (next_s, reward, penalty, done).
            _, rewards, _, finished = self.step_mul(random_action)
            # The printed label is the Chinese word for "reward".
            print('奖励:', rewards[0][0])
            self.env.render()
            if finished:
                self.reset_mul()

    def reshape_done1(self, reward, done):
        """Variant of ``reshape_done`` that clears counters on non-negative reward.

        Rows with a negative reward have their counter incremented. Rows whose
        counter has reached ``self.lim`` are marked done and cleared. Rows
        with a non-negative reward are cleared too, so only uninterrupted
        runs of negative rewards accumulate.

        Args:
            reward: Array with one reward per row.
            done: 2-D done array; modified in place.

        Returns:
            The same ``done`` array that was passed in.
        """
        flat_reward = reward.reshape(-1)
        negative_rows = np.where(flat_reward < 0)[0]
        self.counter[negative_rows, :] += 1
        # A single np.where yields the row indices. The previous nested
        # np.where(np.where(...)) returned positions inside that index array
        # instead, so the wrong rows were marked done and cleared.
        expired_rows = np.where(self.counter.reshape(-1) >= self.lim)[0]
        non_negative_rows = np.where(flat_reward >= 0)[0]
        done[expired_rows, :] = True
        self.counter[expired_rows, :] = 0
        self.counter[non_negative_rows, :] = 0
        return done

if __name__ == "__main__":
    # Manual smoke run: build the default environment and start the endless
    # random-action loop in Env.test.
    demo_env = Env()
    demo_env.test()